Skip to content

feat(dpa4): use edge force and atomic virial#5518

Merged
OutisLi merged 3 commits into
deepmodeling:masterfrom
OutisLi:pr/edgefv
Jun 15, 2026
Merged

feat(dpa4): use edge force and atomic virial#5518
OutisLi merged 3 commits into
deepmodeling:masterfrom
OutisLi:pr/edgefv

Conversation

@OutisLi

@OutisLi OutisLi commented Jun 12, 2026

Copy link
Copy Markdown
Collaborator

Summary by CodeRabbit

  • New Features

    • Added per-atom virial computation function for improved force and virial calculations.
  • Changes

    • Atomic virial is now enabled by default when exporting models, providing extended spatial derivative information without additional cost.
    • Refactored force and virial computation pipeline for improved accuracy and efficiency.
  • Tests

    • Expanded test coverage for edge-based force assembly and virtual-type masking validation.

Copilot AI review requested due to automatic review settings June 12, 2026 09:51
@dosubot dosubot Bot added the new feature label Jun 12, 2026

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR updates the PyTorch SeZM/DPA4 implementation to compute forces and virials via an edge-based gradient (“edge-force scatter”) rather than differentiating w.r.t. coordinates, and adjusts the surrounding compile/export/test infrastructure accordingly. This aligns the SeZM path with an edge-centric internal representation and adds validation tests for the new force/virial assembly.

Changes:

  • Added edge_energy_deriv() to assemble extended force, global virial, and per-atom virial by scattering gradients taken w.r.t. per-edge displacement vectors.
  • Updated SeZMModel to build both local and extended edge index spaces, detach edge vectors as the autograd leaf, and route ZBL bridging through the edge-form InterPotential.
  • Expanded/updated PT tests to use edge-based InterPotential inputs and to finite-difference validate force/virial/atom-virial consistency; adjusted compile-cache key expectations.

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated no comments.

Show a summary per file
File Description
source/tests/pt/model/test_sezm_spin_model.py Updates bridging-mask and compile-cache assertions for new edge-based API and cache key shape.
source/tests/pt/model/test_sezm_model.py Migrates InterPotential tests to edge inputs and adds finite-difference validation for edge-force/virial/atomic-virial assembly.
deepmd/pt/model/model/transform_output.py Adds edge_energy_deriv() to compute force/virial/atomic-virial via edge gradients and explicit scatters.
deepmd/pt/model/model/sezm_model.py Switches SeZM force/virial computation to edge-force scatter; updates edge-list builder to return extended indices; converts ZBL to edge form; simplifies compile-cache key.
deepmd/pt/model/descriptor/sezm_nn/utils.py Avoids creating a scalar tensor for eps_sq in safe_norm.
deepmd/pt/model/descriptor/sezm_nn/edge_cache.py Reworks masked-edge canonical-direction padding using F.pad; removes a small scalar tensor allocation.
deepmd/pt/entrypoints/freeze_pt2.py Changes SeZM .pt2 freeze default to export atomic virial and updates docstring accordingly.
Comments suppressed due to low confidence (1)

deepmd/pt/entrypoints/freeze_pt2.py:484

  • Changing freeze_sezm_to_pt2() default atomic_virial from False to True changes the default .pt2 output contract (extra per-atom virial output keys + metadata do_atomic_virial=true) for all callers that don’t pass the flag, including deepmd.pt.entrypoints.main.freeze(). This is a behavior/API change that can break downstream consumers expecting the previous default output set; consider keeping the default False and letting callers opt in explicitly (or plumb a CLI option).
    atomic_virial: bool = True,
) -> None:
    """Freeze a SeZM checkpoint into an AOTInductor ``.pt2`` archive.

    Parameters

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

@coderabbitai

coderabbitai Bot commented Jun 12, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

📝 Walkthrough

Walkthrough

This PR refactors the SeZM model's force and virial computation pipeline by shifting the differentiation anchor from extended coordinates to an edge-vector autograd leaf, introduces a new edge_energy_deriv function for extended-space derivatives, simplifies compile-cache keying, updates the analytical potential bridge to edge-sparse inputs, and flips the freeze_sezm_to_pt2 default for atomic virial to True.

Changes

SeZM edge-based force/virial refactoring

Layer / File(s) Summary
New edge_energy_deriv function for extended-space derivatives
deepmd/pt/model/model/transform_output.py
edge_energy_deriv computes extended-space force and virial tensors from per-edge energy gradients, handling masking, indexed accumulation, per-atom virial outer products, and optional coordinate corrections.
core_compute refactoring to use edge-vector differentiation
deepmd/pt/model/model/sezm_model.py
core_compute builds edges with explicit extended indices, detaches edge_vec as autograd leaf, calls edge_energy_deriv for force/virial instead of coordinate rebinding, and wires analytical pair potentials to edge-sparse inputs. Documentation updated to describe edge-leaf strategy.
build_edge_list_from_nlist extended index output
deepmd/pt/model/model/sezm_model.py
build_edge_list_from_nlist now computes and returns extended edge indices (edge_index_ext) and refactors edge-vector construction using torch.gather with explicit padding-safe indexing.
InterPotential rewired to edge-sparse inputs
deepmd/pt/model/model/sezm_model.py
InterPotential.forward signature and implementation changed from extended-coordinate/neighbor-list to edge-sparse inputs (edge_vec, edge_index, atype_flat, edge_mask), with masking to exclude virtual/spin nodes.
Compile cache key simplification and trace_and_compile updates
deepmd/pt/model/model/sezm_model.py
Compile cache key reduced from 3-element to 2-element tuple (training, has_coord_corr), removing do_atomic_virial parameter from trace_and_compile signature, and updating cache slot topology and logging.
trace_and_compile coordinate preparation and forward flow
deepmd/pt/model/model/sezm_model.py
Coordinate preparation simplified to detach without conditional grad re-enabling, traced compute_fn calls updated to match new core_compute signature, and both compiled and non-compiled paths aligned with new cache keying.
Lower-interface coordinate detach and export tracing
deepmd/pt/model/model/sezm_model.py
forward_common_lower and forward_common_lower_exportable updated to remove coordinate grad-endpoint rebuilding and move detach inside traced closures to decouple export graphs from upstream LAMMPS gradient sources.
Documentation and NOTE catalogue updates
deepmd/pt/model/model/sezm_model.py
Docstring and NOTE sections rewritten to describe force-derivative flow anchored at edge_vec leaf, revised compile-cache topology, trace-time detach strategy, and inference-path rationale.
Edge cache and safe_norm utility optimizations
deepmd/pt/model/descriptor/sezm_nn/edge_cache.py, deepmd/pt/model/descriptor/sezm_nn/utils.py
Masked-edge canonicalization optimized using F.pad instead of explicit vector addition, inverse-sqrt degree simplified by removing intermediate floor tensor, and safe_norm epsilon changed from tensor to plain float representation.
InterPotential unit tests refactored to edge-sparse signature
source/tests/pt/model/test_sezm_model.py
Introduces _pair_edges helper for directed edge tensor construction, refactors O–O and O–H ZBL known-value tests to use edge-based call signature, and adds gradient and virtual-spin-type masking tests.
New edge-force scatter validation test suite
source/tests/pt/model/test_sezm_model.py
Adds TestSeZMEdgeForceScatter class with finite-difference validation for edge-force scatter correctness (periodic and non-periodic) and atom-virial reduction under descriptor-only and ZBL bridging configurations.
Test compile-cache key updates and edge-list unpacking
source/tests/pt/model/test_sezm_model.py
Updates compile-cache slot assertions to reflect new 2-element key tuples, and adjusts build_edge_list_from_nlist unpacking to handle additional edge_index_ext return value.
Spin model test updates for edge-based masking and cache keys
source/tests/pt/model/test_sezm_spin_model.py
Rewrites test_bridging_masks_virtual_pairs to use explicit edge-graph tensors with virtual node and validates masking by comparing against real-only edge subset; updates compile cache key assertion.

Freeze PT2 atomic_virial default flip

Layer / File(s) Summary
freeze_sezm_to_pt2 atomic_virial default and docstring
deepmd/pt/entrypoints/freeze_pt2.py
atomic_virial parameter default flipped from False to True, with updated docstring describing that per-atom virial export is now default and is a zero-cost by-product.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

  • deepmodeling/deepmd-kit#5483: Both PRs modify deepmd/pt/model/model/sezm_model.py's trace_and_compile/export compilation path, overlapping directly on compile-cache sharing and trace-time handling.
  • deepmodeling/deepmd-kit#5407: Both PRs directly change how atomic_virial is set and consumed during .pt2 export; this PR flips the freeze_sezm_to_pt2 default to True while the other adds atomic_virial flag threading through convert_backend and runtime metadata.

Suggested labels

new feature, Python, Core

Suggested reviewers

  • anyangml
  • njzjz
  • iProzd
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 78.26% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately summarizes the main changes: introducing edge-based force computation and atomic virial functionality throughout the SeZM model pipeline.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
deepmd/pt/model/model/sezm_model.py (1)

2410-2413: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Keep coincident edges in the sparse edge list.

len_positive = edge_len2 > 1e-10 changes the runtime model, not just the trace sample. The eager edge-cache path keeps valid edges at r == 0 and relies on safe_norm / clamp logic downstream to stay finite, but this filter drops them entirely. That makes the sparse SeZM path disagree with the eager path on overlapping configurations and also suppresses the ZBL bridge exactly where the short-range repulsion should be strongest.

Suggested fix
-        len_positive = edge_len2 > 1e-10
-        edge_mask_actual = valid_flat & src_local_valid & len_positive
+        edge_mask_actual = valid_flat & src_local_valid

If the trace-only clamped sample still needs self-edge sanitization, handle that in the trace-input preparation instead of changing the runtime edge semantics.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/model/model/sezm_model.py` around lines 2410 - 2413, The current
filter removes coincident edges by using len_positive = edge_len2 > 1e-10 and
applying it to edge_mask_actual, which changes runtime semantics and breaks
parity with the eager path; remove the len_positive check from the runtime
edge_mask_actual (keep valid_flat & src_local_valid) so edges with r==0 remain
in the sparse SeZM edge list, and if trace-only sanitization is required,
perform clamping/sanitization in the trace-input preparation logic instead of in
the runtime filter; update references around src_local_valid, edge_len2,
edge_mask_actual and ensure downstream safe_norm/clamp logic handles
zero-distance cases as before.
🧹 Nitpick comments (1)
source/tests/pt/model/test_sezm_model.py (1)

1218-1230: ⚡ Quick win

Add a direct ZBL virial finite-difference check.

The bridging_method="ZBL" branch is only validated by atom_virial.sum(dim=1) == virial. If the bridged virial is scattered with the same sign/indexing bug into both tensors, this still passes. Reusing the strain finite-difference check for ZBL would pin the new bridged virial path itself.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@source/tests/pt/model/test_sezm_model.py` around lines 1218 - 1230, In
test_atom_virial_sums_to_global_virial, add a direct finite-difference check for
the "ZBL" branch rather than only asserting atom_virial.sum == virial; after
building the model with _build_model(bridging_method="ZBL") and calling
model(coord, atype, box=box, do_atomic_virial=True), compute the global virial
via the existing numerical finite-difference helper used for strain checks
(reuse the same FD routine used elsewhere in the test suite) and assert that the
model's out["virial"] matches that FD-computed virial within tolerances — this
pins the bridged virial path itself and prevents sign/index scatter from passing
the test.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Outside diff comments:
In `@deepmd/pt/model/model/sezm_model.py`:
- Around line 2410-2413: The current filter removes coincident edges by using
len_positive = edge_len2 > 1e-10 and applying it to edge_mask_actual, which
changes runtime semantics and breaks parity with the eager path; remove the
len_positive check from the runtime edge_mask_actual (keep valid_flat &
src_local_valid) so edges with r==0 remain in the sparse SeZM edge list, and if
trace-only sanitization is required, perform clamping/sanitization in the
trace-input preparation logic instead of in the runtime filter; update
references around src_local_valid, edge_len2, edge_mask_actual and ensure
downstream safe_norm/clamp logic handles zero-distance cases as before.

---

Nitpick comments:
In `@source/tests/pt/model/test_sezm_model.py`:
- Around line 1218-1230: In test_atom_virial_sums_to_global_virial, add a direct
finite-difference check for the "ZBL" branch rather than only asserting
atom_virial.sum == virial; after building the model with
_build_model(bridging_method="ZBL") and calling model(coord, atype, box=box,
do_atomic_virial=True), compute the global virial via the existing numerical
finite-difference helper used for strain checks (reuse the same FD routine used
elsewhere in the test suite) and assert that the model's out["virial"] matches
that FD-computed virial within tolerances — this pins the bridged virial path
itself and prevents sign/index scatter from passing the test.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: c91a1fa5-678f-433e-8e55-1d20f67ae5b2

📥 Commits

Reviewing files that changed from the base of the PR and between 0de53e9 and a5cea0f.

📒 Files selected for processing (7)
  • deepmd/pt/entrypoints/freeze_pt2.py
  • deepmd/pt/model/descriptor/sezm_nn/edge_cache.py
  • deepmd/pt/model/descriptor/sezm_nn/utils.py
  • deepmd/pt/model/model/sezm_model.py
  • deepmd/pt/model/model/transform_output.py
  • source/tests/pt/model/test_sezm_model.py
  • source/tests/pt/model/test_sezm_spin_model.py

@codecov

codecov Bot commented Jun 12, 2026

Copy link
Copy Markdown

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 82.19%. Comparing base (890e38a) to head (a5cea0f).
⚠️ Report is 1 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5518      +/-   ##
==========================================
+ Coverage   81.52%   82.19%   +0.67%     
==========================================
  Files         872      891      +19     
  Lines       97964   101600    +3636     
  Branches     4241     4242       +1     
==========================================
+ Hits        79865    83511    +3646     
+ Misses      16795    16786       -9     
+ Partials     1304     1303       -1     

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

@OutisLi OutisLi added this pull request to the merge queue Jun 14, 2026
Merged via the queue into deepmodeling:master with commit 16df629 Jun 15, 2026
73 checks passed
@OutisLi OutisLi deleted the pr/edgefv branch June 15, 2026 14:57
njzjz pushed a commit to njzjz/deepmd-kit that referenced this pull request Jun 17, 2026
…eepmodeling#5540)

PR-3 (final) of the DPA4/SeZM porting series — pt_expt **inference**:
freeze to `.pt2`, Python `DeepEval`, pt→pt_expt checkpoint interop, C++
single-rank, and LAMMPS single-rank. Follows PR-1 (deepmodeling#5515, dpmodel core)
and PR-2 (deepmodeling#5522, pt_expt training/export).

## What's included

- **Model freeze to `.pt2`** (`deepmd/pt_expt/model/ener_model.py`,
`deepmd/dpmodel/.../ener_model.py`,
`deepmd/dpmodel/descriptor/dpa4.py`): register `EnergyModel` under
`sezm_ener`/`dpa4_ener` model-type aliases so `BaseModel.deserialize`
resolves a standard DPA4 energy model (whose fitting type is
`sezm_ener`). Fixed a `torch.export` specialization where `int()` on
symbolic shapes baked `nf*nloc` (embedding/so2/attention).
- **NoPBC export fix** (`deepmd/dpmodel/descriptor/dpa4.py`): the
`atype_ext[:, :nloc]` slice emitted a spurious `Ne(nall, nloc)` shape
guard that crashed the compiled artifact when `nall==nloc` (no ghosts);
replaced with `xp_take_first_n` (index_select). NoPBC now matches PBC.
- **pt→pt_expt checkpoint interop** (`deepmd/pt_expt/model/model.py`):
`BaseModel.deserialize` unwraps pt's bespoke `SeZMModel` serialization
(`type:"SeZM"`, nested `sezm_atomic` atomic model with the pt-only dens
head), validates versions, rejects unsupported features
(bridging/lora/dens/active_mode) with `NotImplementedError`, and
delegates to the standard path.
- **Warn on silently-ignored flags** (`use_amp` descriptor,
`enable_tf32` model): warn-once instead of silent drop.

## Tests

- **Model freeze** `source/tests/pt_expt/model/test_dpa4_export.py`:
dual-artifact `.pt2`, metadata, AOTI load, artifact-vs-eager parity
(1e-10). *(CI-skipped — AOTI is slow; run locally.)*
- **DeepEval parity vs pt**
`source/tests/pt_expt/infer/test_dpa4_deep_eval.py`: pt `.pt` vs pt_expt
`.pt2` energy/force/global-virial/atom-energy at fp64 1e-10, **PBC and
NoPBC**; doubles as the checkpoint-interop proof. Per-atom virial
compared by sum (pt's edge-scatter from deepmodeling#5518 redistributes it; global
virial matches). *(CI-skipped — AOTI.)*
- **Interop unit tests**
`source/tests/pt_expt/model/test_dpa4_interop.py` (CI-runnable, no
AOTI): happy-path pt-checkpoint→pt_expt round-trip + every guard branch
+ version validation + `@variables` filtering.
- **Alias deserialize guard** + **use_amp/enable_tf32 warn-once** tests
(CI-runnable).
- **Fixture generator** `source/tests/infer/gen_dpa4.py` (+ wired into
`source/install/test_cc_local.sh`).
- **C++ single-rank** `source/api_cc/tests/test_deeppot_dpa4_ptexpt.cc`:
20 tests (double+float), dpa3-matched tolerances. Validated locally.
- **LAMMPS single-rank** `source/lmp/tests/test_lammps_dpa4_pt2.py`:
parity + `atom_modify map yes` + the deepmodeling#5450 no-atom-map fail-fast.
**Validated on a GPU box (7 passed).**

PR-1 parity suites stay green; the small dpmodel edits are
parity-revalidated.

## Known limitations

- **Single-rank only.** Multi-rank/MPI LAMMPS for DPA4 is deferred (no
live multi-rank cell; the with-comm artifact compiles but its runtime is
not exercised). DPA4 is a message-passing descriptor, so multi-rank
follows the existing deepmodeling#5450/deepmodeling#5430 machinery in a later PR.
- **No `.pth` (torch.jit) DPA4** — the pt backend has no `sezm_ener`
*model* registration, so `.pth` freeze of a standard DPA4 energy model
isn't available; not needed for the pt_expt inference path.
- **Per-atom virial** is not compared element-wise pt-vs-pt_expt (only
its global sum) — deepmodeling#5518 changed pt's edge-scatter distribution; both are
correct, the distribution differs.
- **AOTI tests are CI-skipped** (multi-minute compile) — the
freeze/DeepEval paths are validated locally, not in CI; the
interop/alias/warn tests give CI coverage of the non-AOTI logic.
- **fp64 only**; fp32 freeze untested. CUDA validated at LAMMPS level on
a GPU box; the AOTI parity numbers are from CPU fp64.
- **`use_amp`/`enable_tf32`** remain functionally ignored (now warned) —
by design for this series.
- pt SeZM features out of scope (guarded `NotImplementedError`): spin,
ZBL bridging, LoRA, dens/direct-force/denoising heads, SO3 grid
projection, GridMLP, SO(2) attention extensions.


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

# Release Notes

* **New Features**
* Enabled DPA4 model inference via the pt_expt backend using
dual-artifact compilation.
* Registered the EnergyModel under additional aliases: `sezm_ener` and
`dpa4_ener`.

* **Improvements**
* Improved dynamic/symbolic shape handling across DPA4 components for
export/tracing stability.
* Enhanced pt SeZM/DPA4 checkpoint deserialization and normalization for
interoperability.
* Added one-time warnings when `use_amp` or `enable_tf32` settings are
ineffective.

* **Tests**
* Added C++ and Python coverage for pt2 inference, LAMMPS integration,
model export/freeze, parity, interop, and warning behavior.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants